Make onnx export SDPA match aten behavior #2479
Conversation
Codecov Report ❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #2479      +/-   ##
==========================================
- Coverage   69.81%   69.81%   -0.01%
==========================================
  Files         209      209
  Lines       25313    25314       +1
  Branches     2525     2525
==========================================
  Hits        17673    17673
- Misses       6762     6763       +1
  Partials      878      878
```

☔ View full report in Codecov by Sentry.
Do we need to update https://github.com/pytorch/pytorch/blob/main/torch/onnx/_internal/exporter/_torchlib/ops/nn.py as well, or improve specs of the Attention op? @gramalingam @titaiwangms
```python
# This is because there's no safe/masked softmax implementation in ONNX, so we
# need to handle NaN values explicitly to match the behavior of PyTorch with
# boolean masks.
attn_weight = op.Where(op.IsNaN(attn_weight), zero, attn_weight)
attn_weight, _ = op.Dropout(attn_weight, dropout_p)
```
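The motivation for the `Where(IsNaN(...))` line can be sketched with a small NumPy example (NumPy stands in here for the exported graph; the real computation runs in ONNX Runtime): a row whose scores are all `-inf`, i.e. fully masked by a boolean mask, produces NaNs under a naive softmax, and the fix replaces those NaNs with zeros.

```python
import numpy as np

def softmax(x, axis=-1):
    # Naive softmax with the usual max-subtraction trick. For a fully masked
    # row (all -inf), the subtraction gives -inf - (-inf) = nan, which then
    # propagates through exp() and the normalization.
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

mask = np.array([[True, True], [False, False]])  # second row fully masked
scores = np.where(mask, np.array([[0.5, 1.0], [0.3, 0.7]]), -np.inf)

w = softmax(scores)                 # w[1] is [nan, nan]
# Mirror the exported fix: Where(IsNaN(attn_weight), 0, attn_weight)
w_fixed = np.where(np.isnan(w), 0.0, w)
```

With the fix, the fully masked row contributes zero attention weight instead of propagating NaN into the output, which is what the aten behavior expects here.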
@titaiwangms we should probably conditionally skip this line (even though there is a rewrite rule already)
If you fix this, can you also please add a reference to pytorch/pytorch#103749 in the comments for the previous line fixing NaN?
We skip when dropout_p is 0?
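The conditional skip being discussed could look roughly like this (a hypothetical sketch with an illustrative `FakeOp` stand-in, not the actual torch.onnx exporter API): only emit an ONNX Dropout node when `dropout_p` is statically nonzero at export time.

```python
class FakeOp:
    """Stand-in for the exporter's `op` namespace, for illustration only."""
    def __init__(self):
        self.nodes = []

    def Dropout(self, x, p):
        self.nodes.append(("Dropout", p))
        return x, None  # (output, mask), mirroring the two-output call site

def maybe_dropout(op, attn_weight, dropout_p):
    # Dropout(p=0) is the identity, so skip emitting the node entirely.
    if dropout_p == 0.0:
        return attn_weight
    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
    return attn_weight

op = FakeOp()
out = maybe_dropout(op, "attn", 0.0)
assert op.nodes == []  # no Dropout node emitted when p == 0
```

Skipping the node at export time avoids relying on a later rewrite rule to clean up the identity Dropout.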
This PR makes the ONNX SDPA export match the behavior of aten SDPA when a boolean mask is used. Without this fix, the test fails the assertion because the ORT model outputs NaNs.